import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()  # seaborn default plot styling for all figures below
def reset_params(**overrides):
    """Return a fresh dictionary of model and simulation parameters.

    Any keyword argument replaces the matching default, e.g.
    ``reset_params(Tmax=5)``. ``len_T`` (number of time steps) is recomputed
    from the final ``Tmax`` and ``del_t`` unless ``len_T`` is itself given.
    Calling with no arguments returns exactly the original defaults.

    Returns:
        dict mapping parameter name -> value (a new dict on every call).
    """
    params = {
        'a': 270,             # Hz/nA
        'b': 108,             # Hz
        'd': 0.154,           # s
        'gamma': 0.641,
        'tau_s': 100 / 1000,  # s (100 ms)
        'gE': 0.2609,         # nA
        'gI': 0.0497,         # nA
        'gext': 0.00052,      # nA
        'I0': 0.3255,         # nA
        'tau_0': 2 / 1000,    # s (2 ms)
        'sig': 0.02,          # nA
        'Tmax': 3,            # s
        'del_t': 0.001,       # s (1 ms)
        'ntrials': 10,
    }
    params.update(overrides)
    if 'len_T' not in overrides:
        # round() instead of int(): the float quotient can land just below an
        # integer (0.3 / 0.1 == 2.9999999999999996) and int() would drop a step.
        params['len_T'] = int(round(params['Tmax'] / params['del_t']))
    return params
# Module-level parameter set; later cells rebuild it with reset_params().
params = reset_params()
# Font sizes shared by all figures: figure suptitle, subplot titles, axis labels.
fdict = dict(stitle=20, title=16, axes=14)
def firing_rate(I_val, params):
    """Population transfer function r = x / (1 - exp(-d * x)), x = a * I - b.

    Args:
        I_val: input current (nA); a scalar or a NumPy array (elementwise).
        params: dict providing 'a' (Hz/nA), 'b' (Hz) and 'd' (s).

    Returns:
        Firing rate in Hz, same shape as ``I_val`` (a NumPy scalar for
        scalar input).

    At x == 0 the expression is 0/0; its limit 1/d is returned instead of NaN.
    """
    a = params['a']
    b = params['b']
    d = params['d']
    x = a * np.asarray(I_val, dtype=float) - b
    # 1 - exp(-d x) == -expm1(-d x), which stays accurate when d*x is tiny.
    denom = -np.expm1(-d * x)
    with np.errstate(divide='ignore', invalid='ignore'):
        # The x/denom branch is also evaluated where x == 0; np.where discards it.
        r = np.where(x == 0, 1.0 / d, x / denom)
    # r[()] unwraps a 0-d array into a NumPy scalar and leaves arrays unchanged.
    return r[()]
def background_input(params=None):
    """Simulate the noisy background current of both populations.

    Each population's current follows, from Ib(0) = 0,
        Ib(t+1) = Ib(t) + (del_t/tau_0) * (I0 - Ib(t)) + sqrt(del_t/tau_0) * sig * N(0, 1)
    i.e. an Euler-Maruyama step of an Ornstein-Uhlenbeck process relaxing to I0.

    Args:
        params: dict with 'tau_0', 'I0', 'sig', 'del_t', 'len_T'. When omitted,
            the module-level ``params`` is read at call time (the historical
            behaviour callers rely on).

    Returns:
        np.ndarray of shape (2, len_T).
    """
    if params is None:
        # Look the global up at call time: cells rebind ``params`` repeatedly.
        params = globals()['params']
    tau_0 = params['tau_0']
    I0 = params['I0']
    sig = params['sig']
    del_t = params['del_t']
    len_T = params['len_T']
    Ib = np.zeros((2, len_T))
    n_steps = max(len_T - 1, 0)
    # Pre-draw all noise. C order (time, population) consumes the global RNG
    # stream in the same order as a time-outer / population-inner scalar loop.
    noise = np.sqrt(del_t / tau_0) * sig * np.random.normal(size=(n_steps, 2))
    decay = del_t / tau_0
    for t in range(n_steps):
        Ib[:, t + 1] = Ib[:, t] + (decay * (-(Ib[:, t] - I0))) + noise[t]
    return Ib
def synaptic_dynamics(mu, params):
    """Simulate the two-population reduced network for ``ntrials`` trials.

    For population i (j is the other one) at each time step t:
        I_i       = gE * s_i - gI * s_j + Ib_i + gext * mu_i
        r_i       = firing_rate(I_i)
        s_i(t+1)  = s_i(t) + del_t * (r_i * gamma * (1 - s_i) - s_i / tau_s)
    starting from s = 0.1, with one background-noise realisation
    (background_input) per trial. All trials are integrated together.

    Args:
        mu: external input indexed as mu[population, time_step], shape (2, len_T).
        params: dict with 'gamma', 'tau_s', 'gE', 'gI', 'gext', 'del_t',
            'ntrials', 'len_T' (plus the keys firing_rate reads).

    Returns:
        (s, r): synaptic drive and firing rate (Hz), each (2, len_T, ntrials).
        r is filled for every step, including the last one.
    """
    gamma = params['gamma']
    tau_s = params['tau_s']
    gE = params['gE']
    gI = params['gI']
    gext = params['gext']
    del_t = params['del_t']
    ntrials = params['ntrials']
    len_T = params['len_T']
    mu = np.asarray(mu)
    s = 0.1 * np.ones((2, len_T, ntrials))
    r = np.zeros((2, len_T, ntrials))
    if ntrials == 0:
        return s, r
    # Draw each trial's noise in trial order, so the global RNG is consumed
    # exactly as in a trial-by-trial loop. NOTE: background_input() is called
    # without arguments, so it uses the module-level params; its len_T must match.
    Ib = np.stack([background_input() for _ in range(ntrials)], axis=-1)
    for t in range(len_T):
        s_t = s[:, t, :]
        I = np.empty_like(s_t)
        I[0] = gE * s_t[0] - gI * s_t[1] + Ib[0, t] + gext * mu[0, t]
        I[1] = gE * s_t[1] - gI * s_t[0] + Ib[1, t] + gext * mu[1, t]
        r_t = firing_rate(I, params)
        r[:, t, :] = r_t
        if t + 1 < len_T:
            del_s = (r_t * gamma * (1 - s_t)) - (s_t / tau_s)
            s[:, t + 1, :] = s_t + del_t * del_s
    return s, r
def plot_fr_dynamics(r, s, stim, t_stim, qnum, params, fdict):
    """Plot input/s(t)/r(t) time courses, then the s and r state spaces.

    Args:
        r, s: firing rates and synaptic drives, shape (2, len_T, ntrials);
            all trials are overlaid in the time-course panels, only trial 0 is
            drawn in the state-space panels.
        stim: external input, shape (2, len_T).
        t_stim: stimulus windows as [[start, stop], ...] in time steps; both
            edges of each window are marked by dashed magenta lines.
        qnum: exercise tag selecting the figure title ('1a', '1b1', '1b2',
            '1b3', '1d', '1e', '2', or '3<k>' for the k-th coherence level).
            Any other tag is used verbatim as the title.
        params: dict providing 'del_t', 'Tmax' (and 'gE' for tag '1e').
        fdict: font sizes with keys 'stitle', 'title', 'axes'.
    """
    contrasts = [0.032, 0.064, 0.128, 0.256, 0.512, 0.85, 1]
    if qnum == '1a':
        tname = 'Resting state'
    elif qnum == '1b1':
        tname = 'Stimulus 1 on'
    elif qnum == '1b2':
        tname = 'Stimulus 2 on'
    elif qnum == '1b3':
        tname = 'Back to resting state'
    elif qnum == '1d':
        tname = 'Effect of distractor'
    elif qnum == '1e':
        tname = f'gE = {round(params["gE"], 3)}'
    elif qnum == '2':
        tname = 'Coin-tossing simulation for 10 trials'
    elif qnum.startswith('3'):
        c_num = int(qnum[1])
        tname = f'Coherence = {contrasts[c_num]}'
    else:
        # Previously tname stayed unbound here and plt.suptitle raised.
        tname = str(qnum)

    Tmax = params['Tmax']
    del_t = params['del_t']
    # Time axis from the data length: np.arange(0, Tmax, del_t) with a float
    # step can produce one sample more or fewer than the simulation has.
    time = np.arange(s.shape[1]) * del_t
    ticker_range = np.arange(0, Tmax + 0.1, 0.5)
    # Stimulus on/off times in seconds (first two columns of each window row).
    edges = np.atleast_2d(np.asarray(t_stim))[:, :2].ravel() * del_t

    # ---- Figure 1: input current, s(t), r(t) ----
    fig, axs = plt.subplots(1, 3, figsize=(20, 6))
    axs[0].plot(time, stim[0, :], 'k')
    axs[0].plot(time, stim[1, :], 'r')
    axs[0].set_ylabel('Stimulus', fontsize=fdict['axes'])
    axs[0].set_title('Input current', fontsize=fdict['title'])
    # The last sample is skipped: r[:, -1] may be left unset (zero) by the simulation.
    panels = (
        (axs[1], s, 'Synaptic drive', 's(t)'),
        (axs[2], r, 'Firing rate (Hz)', 'r(t)'),
    )
    for ax, data, ylabel, title in panels:
        ax.plot(time[:-1], data[0, :-1, :], 'k')
        ax.plot(time[:-1], data[1, :-1, :], 'r')
        for edge in edges:
            ax.axvline(x=edge, color='m', linestyle='--')
        ax.set_ylabel(ylabel, fontsize=fdict['axes'])
        ax.set_title(title, fontsize=fdict['title'])
    axs[1].set_ylim([0, 1])
    for ax in axs:
        ax.set_xlabel('Time (s)', fontsize=fdict['axes'])
        ax.set_xticks(ticker_range)
    plt.suptitle(tname, fontsize=fdict['stitle'])
    plt.show()

    # ---- Figure 2: state-space trajectories of trial 0, coloured by time ----
    fig, axs = plt.subplots(1, 2, figsize=(20, 6))
    axs[0].scatter(s[0, :-1, 0], s[1, :-1, 0], s=5, c=time[:-1], cmap='plasma')
    axs[0].set_xlabel('S1', fontsize=fdict['axes'])
    axs[0].set_ylabel('S2', fontsize=fdict['axes'])
    axs[0].set_xlim([0, 1])
    axs[0].set_ylim([0, 1])
    axs[0].set_title('State space s(t)', fontsize=fdict['title'])
    max_r = np.ceil(np.nanmax(r))
    p2 = axs[1].scatter(r[0, :-1, 0], r[1, :-1, 0], s=5, c=time[:-1], cmap='plasma')
    axs[1].set_xlabel('Firing rate 1 (Hz)', fontsize=fdict['axes'])
    axs[1].set_ylabel('Firing rate 2 (Hz)', fontsize=fdict['axes'])
    axs[1].set_xlim([0, max_r])
    axs[1].set_ylim([0, max_r])
    axs[1].set_title('State space r(t)', fontsize=fdict['title'])
    fig.colorbar(p2, ax=axs[1], label='Time (s)')
    plt.show()
# --- 1a: resting state, no external input for the whole trial ---
n_steps = params['len_T']
mu = np.zeros((2, n_steps))
t_stim = [[0, n_steps]]
s, r = synaptic_dynamics(mu, params)
for pop in (1, 2):
    mean_rate = round(np.mean(r[pop - 1, :, :]), 2)
    print(f'Mean resting state firing rate for population {pop} is {mean_rate} Hz')
plot_fr_dynamics(r, s, mu, t_stim, '1a', params, fdict)
# Output: Mean resting state firing rate for population 1 is 2.1 Hz
# Output: Mean resting state firing rate for population 2 is 2.16 Hz
# --- 1b1: 300-step stimulus (amplitude 35) to population 1 ---
mu = np.zeros((2, params['len_T']))
t_stim = [[500, 800]]
on, off = t_stim[0]
mu[0, on:off] = 35
s, r = synaptic_dynamics(mu, params)
# NOTE: despite the printed wording, these means are taken over the
# post-stimulus period (from stimulus offset to the end of the trial).
for pop in (1, 2):
    mean_rate = round(np.mean(r[pop - 1, off:, :]), 2)
    print(f'Mean resting state firing rate for population {pop} is {mean_rate} Hz')
plot_fr_dynamics(r, s, mu, t_stim, '1b1', params, fdict)
# Output: Mean resting state firing rate for population 1 is 18.85 Hz
# Output: Mean resting state firing rate for population 2 is 0.68 Hz
# --- 1b2: 300-step stimulus (amplitude 35) to population 2 ---
mu = np.zeros((2, params['len_T']))
t_stim = [[500, 800]]
on, off = t_stim[0]
mu[1, on:off] = 35
s, r = synaptic_dynamics(mu, params)
# NOTE: despite the printed wording, these means are taken over the
# post-stimulus period (from stimulus offset to the end of the trial).
for pop in (1, 2):
    mean_rate = round(np.mean(r[pop - 1, off:, :]), 2)
    print(f'Mean resting state firing rate for population {pop} is {mean_rate} Hz')
plot_fr_dynamics(r, s, mu, t_stim, '1b2', params, fdict)
# Output: Mean resting state firing rate for population 1 is 0.69 Hz
# Output: Mean resting state firing rate for population 2 is 18.25 Hz
# --- 1b3: 5 s trial; stimulus (35) to population 1, later a stronger one (85) to population 2 ---
params = reset_params()
params['Tmax'] = 5
params['len_T'] = int(params['Tmax'] / params['del_t'])
mu = np.zeros((2, params['len_T']))
t_stim = [[500, 800], [3000, 3300]]
amplitudes = (35, 85)
# Row k of t_stim / amplitudes drives population k.
for pop_idx, ((start, stop), amp) in enumerate(zip(t_stim, amplitudes)):
    mu[pop_idx, start:stop] = amp
s, r = synaptic_dynamics(mu, params)
plot_fr_dynamics(r, s, mu, t_stim, '1b3', params, fdict)
# --- 1d: back to a 3 s trial; equal-strength stimuli to population 1, then population 2 ---
params['Tmax'] = 3
params['len_T'] = int(params['Tmax'] / params['del_t'])
mu = np.zeros((2, params['len_T']))
t_stim = [[500, 800], [1800, 2100]]
# Row k of t_stim is the window for population k; both have amplitude 35.
for pop_idx, (start, stop) in enumerate(t_stim):
    mu[pop_idx, start:stop] = 35
s, r = synaptic_dynamics(mu, params)
plot_fr_dynamics(r, s, mu, t_stim, '1d', params, fdict)
# --- 1e: lower gE by 0.001 per run (15 runs) with the same stimulus to population 1 ---
params = reset_params()
t_stim = [[500, 800]]
on, off = t_stim[0]
# The input is identical for every run, so it is built once.
mu = np.zeros((2, params['len_T']))
mu[0, on:off] = 35
for _step in range(15):
    params['gE'] -= 0.001
    s, r = synaptic_dynamics(mu, params)
    plot_fr_dynamics(r, s, mu, t_stim, '1e', params, fdict)
# --- 2: equal inputs (30) to both populations for 1000 steps, 10 trials ---
params = reset_params()
mu = np.zeros((2, params['len_T']))
t_stim = [[500, 1500]]
on, off = t_stim[0]
mu[0, on:off] = 30
mu[1, on:off] = 30
s, r = synaptic_dynamics(mu, params)
# Choice = population with the higher end-of-trial rate: 1 if r1 > r2,
# 2 if r2 > r1, 0 on an exact tie. This matches the labelling of the other
# choice computations (it was previously inverted here). Index -2 is used
# because r[:, -1] may be left unset by synaptic_dynamics.
r_end = r[:, -2, :]
choice = np.zeros((params['ntrials'],), dtype='int')
choice[r_end[0] > r_end[1]] = 1
choice[r_end[0] < r_end[1]] = 2
print(f'Number of trials for which choice is 1 is {np.sum(choice == 1)}')
print(f'Number of trials for which choice is 2 is {np.sum(choice == 2)}')
plot_fr_dynamics(r, s, mu, t_stim, '2', params, fdict)
# Output: Number of trials for which choice is 1 is 5
# Output: Number of trials for which choice is 2 is 5
# --- 2 (continued): same equal-input experiment with 500 trials ---
params = reset_params()
params['ntrials'] = 500
mu = np.zeros((2, params['len_T']))
t_stim = [[500, 1500]]
start, stop = t_stim[0]
mu[:, start:stop] = 30
s, r = synaptic_dynamics(mu, params)
# Compare the two populations' rates at the second-to-last step of each trial.
rate_1, rate_2 = r[0, -2, :], r[1, -2, :]
choice = np.zeros((params['ntrials'],), dtype='int')
choice[rate_1 > rate_2] = 1
choice[rate_1 < rate_2] = 2
for option in (1, 2):
    print(f'Number of trials for which choice is {option} is {np.sum(choice == option)}')
# Output: Number of trials for which choice is 1 is 261
# Output: Number of trials for which choice is 2 is 239
# --- 3: psychometric function. Population 1 gets mu_0*(1+c) and population 2
# gets mu_0*(1-c); "correct" means population 1 ends with the higher rate. ---
coherence = [0.032, 0.064, 0.128, 0.256, 0.512, 0.85, 1]
coh_100 = 100 * np.asarray(coherence)
percentage_correct = np.zeros(len(coherence),)
mu_0 = 30
params = reset_params()
params['ntrials'] = 100
t_stim = [[500, 1500]]
on, off = t_stim[0]
for idx, c in enumerate(coherence):
    mu = np.zeros((2, params['len_T']))
    mu[0, on:off] = mu_0 * (1 + c)
    mu[1, on:off] = mu_0 * (1 - c)
    s, r = synaptic_dynamics(mu, params)
    n_correct = np.sum(r[0, -2, :] > r[1, -2, :])
    percentage_correct[idx] = 100 * (n_correct / params['ntrials'])
log_coh = np.log(coherence)
plt.figure(figsize=(8, 6))
plt.plot(log_coh, percentage_correct, 'ro-')
plt.xlabel('% Coherence', fontsize=fdict['axes'])
plt.ylabel('% Correct', fontsize=fdict['axes'])
plt.title('Psychometric Function', fontsize=fdict['stitle'])
plt.xticks(ticks=log_coh, labels=coh_100, rotation=20)
plt.show()
# --- 3: reaction times. A decision is the first time step at which either
# population's rate reaches f_thresh; the earlier population is the choice. ---
coherence = [0.032, 0.064, 0.128, 0.256, 0.512, 0.85, 1]
coh_100 = 100 * np.asarray(coherence)
percentage_correct = np.zeros(len(coherence),)
mu_0 = 30
f_thresh = 15  # Hz
params = reset_params()
params['ntrials'] = 50
del_t = params['del_t']
# RT in seconds, measured from trial start (t = 0). NaN marks undecided
# trials (no crossing, or both populations cross at the same step); they were
# previously stored as 0 s and pulled the mean/std RT down.
RT = np.full((params['ntrials'], len(coherence)), np.nan)
t_stim = [[500, 1500]]
on, off = t_stim[0]
for ci, c in enumerate(coherence):
    mu = np.zeros((2, params['len_T']))
    mu[0, on:off] = mu_0 * (1 + c)
    mu[1, on:off] = mu_0 * (1 - c)
    s, r = synaptic_dynamics(mu, params)
    plot_fr_dynamics(r, s, mu, t_stim, f'3{ci}', params, fdict)
    crossed = r >= f_thresh  # (2, len_T, ntrials)
    # First crossing index per population and trial; len_T when never crossed.
    first = np.where(crossed.any(axis=1), crossed.argmax(axis=1), params['len_T'])
    choice = np.zeros((params['ntrials'],), dtype='int')
    choice[first[0] < first[1]] = 1
    choice[first[0] > first[1]] = 2
    RT[choice == 1, ci] = first[0, choice == 1] * del_t
    RT[choice == 2, ci] = first[1, choice == 2] * del_t
    percentage_correct[ci] = 100 * (np.sum(choice == 1) / params['ntrials'])

log_coh = np.log(coherence)
fig, axs = plt.subplots(1, 2, figsize=(20, 6))
# nanmean/nanstd ignore undecided trials (an all-NaN column yields NaN with a warning).
axs[0].plot(log_coh, np.nanmean(RT, axis=0), 'ro-')
axs[0].set_xlabel('% Coherence', fontsize=fdict['axes'])
axs[0].set_ylabel('Mean RT (s)', fontsize=fdict['axes'])
axs[0].set_title('Mean Reaction Time', fontsize=fdict['title'])
axs[0].set_xticks(ticks=log_coh, labels=coh_100, rotation=20)
axs[1].plot(log_coh, np.nanstd(RT, axis=0), 'ro-')
axs[1].set_xlabel('% Coherence', fontsize=fdict['axes'])
axs[1].set_ylabel('std RT (s)', fontsize=fdict['axes'])
axs[1].set_title('Stdev Reaction Time', fontsize=fdict['title'])
axs[1].set_xticks(ticks=log_coh, labels=coh_100, rotation=20)
plt.show()
# --- 3: effect of stimulus duration on reaction time and accuracy ---
coherence = [0.032, 0.064, 0.128, 0.256, 0.512, 0.85, 1]
coh_100 = 100 * np.asarray(coherence)
mu_0 = 30
f_thresh = 15  # Hz
params = reset_params()
params['ntrials'] = 30
del_t = params['del_t']
onset = 500
dur_t = [100, 300, 500, 800]  # stimulus durations in time steps
col_arr = ['r', 'k', 'b', 'm', 'g', 'y', 'c']  # one colour per coherence
# RT in seconds from trial start; NaN marks undecided trials (previously 0 s,
# which biased the mean RT downward).
RT = np.full((params['ntrials'], len(coherence), len(dur_t)), np.nan)
percentage_correct = np.zeros((len(coherence), len(dur_t)))
fig, axs = plt.subplots(1, 2, figsize=(20, 6))
for ci, c in enumerate(coherence):
    for di, dur in enumerate(dur_t):
        t_stim = [[onset, onset + dur]]
        mu = np.zeros((2, params['len_T']))
        mu[0, onset:onset + dur] = mu_0 * (1 + c)
        mu[1, onset:onset + dur] = mu_0 * (1 - c)
        s, r = synaptic_dynamics(mu, params)
        crossed = r >= f_thresh  # (2, len_T, ntrials)
        # First crossing index per population and trial; len_T when never crossed.
        first = np.where(crossed.any(axis=1), crossed.argmax(axis=1), params['len_T'])
        chose_1 = first[0] < first[1]
        chose_2 = first[0] > first[1]
        RT[chose_1, ci, di] = first[0, chose_1] * del_t
        RT[chose_2, ci, di] = first[1, chose_2] * del_t
        percentage_correct[ci, di] = 100 * (np.sum(chose_1) / params['ntrials'])
    label = f'coh = {coh_100[ci]}'
    axs[0].plot(dur_t, np.nanmean(RT[:, ci, :], axis=0), col_arr[ci], linestyle='--', marker='o', label=label)
    axs[1].plot(dur_t, percentage_correct[ci, :], col_arr[ci], linestyle='--', marker='o', label=label)
axs[0].set_xlabel('Stimulus Duration (ms)', fontsize=fdict['axes'])
axs[0].set_ylabel('Mean RT (s)', fontsize=fdict['axes'])
axs[0].set_title('Chronometric curve', fontsize=fdict['title'])
axs[1].set_xlabel('Stimulus Duration (ms)', fontsize=fdict['axes'])
axs[1].set_ylabel('% Correct', fontsize=fdict['axes'])
axs[1].set_title('Psychometric curve', fontsize=fdict['title'])
# plt.legend() only labelled the current (right) axes; label both panels.
for ax in axs:
    ax.legend()
plt.show()